import collections
import datetime
import math
import os

import numpy as np
import pandas as pd
import pyecharts
from pyecharts import options as opts
from pyecharts import globals as pyecharts_globals
from pyecharts.commons.utils import JsCode

# Marker symbols used by the position scatter series in create_positions
# (passed as `symbol=`). ECharts accepts "image://<url or data URI>" as a
# custom marker symbol.
# NOTE(review): every value below is only the bare "image://" prefix with no
# image after it — the payload looks stripped; confirm the intended images.
LONG_ENTRY = "image://"
SHORT_ENTRY = "image://"
PROFIT_CLOSE = "image://"
LOSS_CLOSE = "image://"


class CustomDatetime(datetime.datetime):
    """``datetime.datetime`` subclass used for chart axis labels.

    ``str()`` renders as ``"YYYY-MM-DD\\nHH:MM"`` so the date and the time can
    be shown on separate lines in axis/tooltip labels.
    """

    # Interval-string suffix -> ``datetime.timedelta`` keyword ("5m", "1h", "1d").
    _INTERVAL_UNITS = {"m": "minutes", "h": "hours", "d": "days"}

    @staticmethod
    def from_datetime(original_datetime: datetime.datetime):
        """Return a ``CustomDatetime`` with every field (incl. tzinfo/fold) copied."""
        return CustomDatetime(
            original_datetime.year,
            original_datetime.month,
            original_datetime.day,
            original_datetime.hour,
            original_datetime.minute,
            original_datetime.second,
            original_datetime.microsecond,
            original_datetime.tzinfo,
            fold=original_datetime.fold,
        )

    @staticmethod
    def strings(dates_list):
        """Return ``str(d)`` for each date, i.e. the two-line label form."""
        return [str(d) for d in dates_list]

    @classmethod
    def from_timestamp(cls, timestamp):
        """Convert a millisecond epoch timestamp to a naive *local-time* value.

        ``datetime.fromtimestamp`` without ``tz`` uses the machine's local
        timezone, so labels depend on where the plot is generated.
        """
        return CustomDatetime.from_datetime(datetime.datetime.fromtimestamp(timestamp / 1000))

    @classmethod
    def interval_to_timedelta(cls, interval_string):
        """Parse an interval such as ``"5m"``, ``"1h"`` or ``"1d"``.

        Raises:
            ValueError: If the suffix is not one of m/h/d, the string is
                empty, or the amount before the suffix is not an integer.
        """
        message = f"unsupported candles_interval format '{interval_string}'"
        # Slicing (not indexing) so an empty string is rejected cleanly
        # instead of raising IndexError.
        unit = cls._INTERVAL_UNITS.get(interval_string[-1:])
        if unit is None:
            raise ValueError(message)
        try:
            amount = int(interval_string[:-1])
        except ValueError as err:
            raise ValueError(message) from err
        return datetime.timedelta(**{unit: amount})

    def get_minute_rounded(self):
        """Return a copy truncated (not rounded) to the start of its minute."""
        return CustomDatetime.from_datetime(
            self - datetime.timedelta(seconds=self.second, microseconds=self.microsecond)
        )

    def __str__(self):
        return self.strftime("%Y-%m-%d\n%H:%M")


def dump_interactive_plot(
    config: collections.OrderedDict,
    data: np.ndarray,
    longs: pd.DataFrame,
    shorts: pd.DataFrame,
    candles_interval=None,
    theme="",
):
    """Render candlesticks plus long/short fill markers to an HTML file.

    Args:
        config: Backtest config. Keys read here: ``"ohlcv"``, ``"symbol"``,
            ``"start_date"``, ``"end_date"``, ``"plots_dirpath"`` and,
            optionally, ``"plot_candles_interval"`` and ``"plot_theme"``.
        data: Price rows forwarded to ``create_graphs``.
        longs: Long-side fills forwarded to ``create_positions``.
        shorts: Short-side fills forwarded to ``create_positions``.
        candles_interval: ``datetime.timedelta`` or a string such as
            ``"5m"``/``"1h"``/``"1d"``. ``None`` falls back to
            ``config["plot_candles_interval"]``, then to one minute.
        theme: pyecharts theme name. ``""`` falls back to
            ``config["plot_theme"]``, then to ``ThemeType.INFOGRAPHIC``.

    The chart is written to ``interactive_plot.html`` inside
    ``config["plots_dirpath"]``.
    """
    if candles_interval is None:
        candles_interval = config.get("plot_candles_interval", datetime.timedelta(minutes=1))
    # isinstance (not `type(...) is str`) so str subclasses such as numpy.str_
    # read from a config are parsed too.
    if isinstance(candles_interval, str):
        candles_interval = CustomDatetime.interval_to_timedelta(candles_interval)

    if theme == "":
        theme = config.get("plot_theme", pyecharts_globals.ThemeType.INFOGRAPHIC)

    # Build the candlestick chart and overlay every position marker series on it.
    candlesticks = create_graphs(data, candles_interval, config["ohlcv"])
    scatters = (*create_positions(longs, True), *create_positions(shorts, False))
    for scatter in scatters:
        candlesticks.overlap(scatter)

    grid_chart = pyecharts.charts.Grid(
        init_opts=opts.InitOpts(
            width="100%",
            height="1000px",
            animation_opts=opts.AnimationOpts(animation=True),
            page_title=f"{config['symbol']} : {config['start_date']} to {config['end_date']}",
            theme=theme,
        )
    )
    grid_chart.add(
        candlesticks,
        grid_opts=opts.GridOpts(pos_left="10%", pos_right="8%", height="80%"),
    )

    # os.path.join works whether or not plots_dirpath ends with a separator
    # (plain `+` silently glued the file name onto the directory name) and
    # also accepts pathlib.Path values.
    grid_chart.render(os.path.join(config["plots_dirpath"], "interactive_plot.html"))


def create_graphs(data, candles_interval, is_ohlcv=True):
    """Aggregate backtest rows into candles and build a candlestick chart.

    Args:
        data: Row sequence whose column 0 is a millisecond epoch timestamp
            (see ``CustomDatetime.from_timestamp``). With ``is_ohlcv`` rows
            are read as ``(timestamp, high, low, close, ...)``; otherwise as
            ticks whose single price is in column 2.
        candles_interval: ``datetime.timedelta`` width of each candle.
        is_ohlcv: True for high/low/close rows, False for tick rows.

    Returns:
        ``pyecharts.charts.Candlestick`` with one point per candle, each as
        ``[open, close, low, high]`` (the order the chart expects).

    Raises:
        IOError: If the first row has too few columns for the chosen format.
    """
    # Pre-allocate one slot per input row; candles only open on row
    # boundaries, so this bounds the candle count for intervals >= 1 minute.
    candles_date = np.empty((len(data),), dtype=datetime.datetime)
    candles_data = np.empty((len(data),), dtype=object)

    # The first candle starts at the first row's timestamp truncated to the minute.
    current_date = CustomDatetime.from_timestamp(data[0][0]).get_minute_rounded()
    candles_date[0] = current_date
    if is_ohlcv:  # hlc format
        if len(data[0]) < 4:
            raise IOError("Backtest ohlcv data format seems to be invalid.")

        # The data carries no open price, so the first close doubles as open.
        first_open = data[0][3]
        first_high = data[0][1]
        first_low = data[0][2]
        first_close = data[0][3]
    elif len(data[0]) >= 3:  # ticks format: a single price per tick
        first_open = first_high = first_low = first_close = data[0][2]
    else:
        raise IOError("Backtest data format seems to be invalid.")
    candles_data[0] = [first_open, first_close, first_low, first_high]

    next_date = current_date + candles_interval
    candles_index = 0

    # Row 0 is folded in again below; that is a no-op because its close/low/high
    # already seed candle 0.
    for data_row in data:
        current_date = CustomDatetime.from_timestamp(data_row[0])

        if is_ohlcv:
            high = data_row[1]
            low = data_row[2]
            close = data_row[3]
        else:  # ticks format
            high = low = close = data_row[2]

        if current_date >= next_date:
            # New candle: it opens at the previous candle's close and its
            # window starts at this row's minute (not on a fixed grid).
            current_date = current_date.get_minute_rounded()
            next_date = current_date + candles_interval

            candles_index += 1
            candles_date[candles_index] = current_date
            candles_data[candles_index] = [candles_data[candles_index - 1][1], close, low, high]

        # Update the current candle's close/low/high.
        candles_data[candles_index][1] = close
        candles_data[candles_index][2] = min(candles_data[candles_index][2], low)
        candles_data[candles_index][3] = max(candles_data[candles_index][3], high)

    # candles_index is the index of the last filled candle, so the count is
    # one more; slicing with candles_index alone dropped the latest candle
    # (and produced an empty chart when all data fit in one candle).
    candles_count = candles_index + 1

    candlesticks = (
        pyecharts.charts.Candlestick()
        .add_xaxis(xaxis_data=CustomDatetime.strings(candles_date[:candles_count]))
        .add_yaxis(
            series_name="Candlesticks",
            y_axis=list(candles_data[:candles_count]),
            itemstyle_opts=opts.ItemStyleOpts(
                color0="#ef232a",
                color="#14b143",
                border_color0="#ef232a",
                border_color="#14b143",
            ),
        )
        .set_global_opts(
            xaxis_opts=opts.AxisOpts(
                is_show=True,
                is_scale=True,
            ),
            yaxis_opts=opts.AxisOpts(
                is_scale=True,
                is_show=True,
                splitarea_opts=opts.SplitAreaOpts(
                    is_show=True, areastyle_opts=opts.AreaStyleOpts(opacity=1)
                ),
            ),
            datazoom_opts=[
                opts.DataZoomOpts(
                    type_="inside",
                    is_show=False,
                    range_start=0,
                    range_end=100,
                ),
                opts.DataZoomOpts(
                    type_="slider",
                    is_show=True,
                    range_start=0,
                    range_end=100,
                ),
            ],
        )
    )
    return candlesticks


def _position_scatter(
    series_name,
    timestamps,
    points,
    symbol,
    label_color,
    label_formatter,
    tooltip_formatter,
    global_options,
    dataset=None,
    symbol_size=15,
):
    """Build one position-marker Scatter chart for ``create_positions``.

    Args:
        series_name: Legend name of the series.
        timestamps: ``CustomDatetime`` x-values, one per point.
        points: Tuples whose first element is the fill price; the remaining
            elements are preformatted strings read by the JS formatters
            (which index ``params.value`` as ``[date, price, ...]``).
        symbol: ECharts marker symbol.
        label_color: Color of the point labels.
        label_formatter: JavaScript function source for the point labels.
        tooltip_formatter: JavaScript function source for the tooltip.
        global_options: Keyword arguments for ``set_global_opts``.
        dataset: Optional ``add_dataset`` source, applied before the global options.
        symbol_size: Marker size.
    """
    chart = (
        pyecharts.charts.Scatter()
        .add_xaxis(CustomDatetime.strings(timestamps))
        .add_yaxis(
            series_name,
            y_axis=points,
            encode=dict(x=0, y=1),
            symbol=symbol,
            symbol_size=symbol_size,
            label_opts=opts.LabelOpts(color=label_color, formatter=JsCode(label_formatter)),
            tooltip_opts=opts.TooltipOpts(formatter=JsCode(tooltip_formatter)),
        )
    )
    if dataset is not None:
        chart = chart.add_dataset(dataset)
    return chart.set_global_opts(**global_options)


def create_positions(fills: pd.DataFrame, long: bool):
    """Build entry, profit-close and loss-close marker series for one side.

    Args:
        fills: One row per fill with columns ``timestamp`` (milliseconds, see
            ``CustomDatetime.from_timestamp``), ``type``, ``price`` and
            ``wallet_exposure``; rows whose ``type`` contains ``"close"``
            also need ``pnl``. Rows matching neither ``"entry"`` nor
            ``"close"`` are ignored.
        long: True to name/mark the series as longs, False for shorts.

    Returns:
        ``(entries, profits, losses)`` — ``pyecharts.charts.Scatter`` charts
        meant to be overlapped on the candlestick chart. Closes with
        ``pnl >= 0`` count as profits, the rest as losses.
    """
    side = "Long" if long else "Short"

    entries_timestamps, entries_points = [], []
    profits_timestamps, profits_points = [], []
    losses_timestamps, losses_points = [], []

    for _, fill in fills.iterrows():
        timestamp = CustomDatetime.from_timestamp(fill["timestamp"])
        # wallet_exposure is a fraction; shown as a percentage.
        we = "%.02f%%" % (fill["wallet_exposure"] * 100,)

        if "entry" in fill["type"]:
            entries_timestamps.append(timestamp)
            entries_points.append((fill["price"], we))
        elif "close" in fill["type"]:
            # NOTE(review): unlike wallet_exposure, pnl is not scaled by 100
            # before the "%" suffix is appended — confirm pnl's unit.
            pnl = "%.02f%%" % (fill["pnl"],)
            point = (fill["price"], pnl, we)
            if fill["pnl"] >= 0:
                profits_timestamps.append(timestamp)
                profits_points.append(point)
            else:
                losses_timestamps.append(timestamp)
                losses_points.append(point)

    # Shared by all three scatters; their own axes are hidden because they
    # are drawn over the candlestick chart.
    scatters_global_options = dict(
        xaxis_opts=opts.AxisOpts(
            is_show=False,
            is_scale=True,
        ),
        yaxis_opts=opts.AxisOpts(
            is_scale=True,
            is_show=False,
        ),
        datazoom_opts=[
            opts.DataZoomOpts(
                type_="inside",
                is_show=False,
                range_start=0,
                range_end=100,
            ),
            opts.DataZoomOpts(
                type_="slider",
                is_show=True,
                range_start=0,
                range_end=100,
            ),
        ],
    )

    entries = _position_scatter(
        series_name=side + " Positions",
        timestamps=entries_timestamps,
        points=entries_points,
        symbol=LONG_ENTRY if long else SHORT_ENTRY,
        label_color="#2c2f30",
        label_formatter="""function (params) {
                                return 'WE ' + params.value[2];
                               }
                           """,
        tooltip_formatter="""function (params) {
                                return params.value[0] +
                                       '<br>Price: ' + params.value[1] +
                                       '<br>WE: ' + params.value[2];
                                }
                           """,
        global_options=scatters_global_options,
        # NOTE(review): the series already carries its own data, so this
        # WE-only dataset appears unused; kept to preserve the emitted options.
        dataset=[we for _, we in entries_points],
    )

    profits = _position_scatter(
        series_name=side + " Profits",
        timestamps=profits_timestamps,
        points=profits_points,
        symbol=PROFIT_CLOSE,
        label_color="#0b611c",
        label_formatter="""function (params) {
                                return '▲'+params.value[2];
                               }
                           """,
        tooltip_formatter="""function (params) {
                                return params.value[0] +
                                       '<br>Price:  ' + params.value[1] +
                                       '<br>Profit: ' + params.value[2] +
                                       '<br>WE:     ' + params.value[3];
                                }
                           """,
        global_options=scatters_global_options,
    )

    losses = _position_scatter(
        series_name=side + " Losses",
        timestamps=losses_timestamps,
        points=losses_points,
        symbol=LOSS_CLOSE,
        label_color="#610b0b",
        label_formatter="""function (params) {
                                return '▼' + params.value[2];
                               }
                           """,
        tooltip_formatter="""function (params) {
                                return params.value[0] +
                                       '<br>Price: ' + params.value[1] +
                                       '<br>Loss:  ' + params.value[2] +
                                       '<br>WE:    ' + params.value[3];
                                }
                           """,
        global_options=scatters_global_options,
    )

    return entries, profits, losses
